{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "962d87bb",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "# Getting Started\n",
    "\n",
    "## Overview\n",
    "\n",
    "Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper, Ada, as well as 8-bit and 4-bit floating point (NVFP4) precision on Blackwell GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your JAX code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n",
    "\n",
    "This guide shows how to start using Transformer Engine with JAX. Similar tutorial for pyTorch is available [here](quickstart.ipynb).\n",
    "We recommend you to try understanding the basics of JAX first, using these resources:\n",
    "\n",
    "- Thinking in JAX: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html\n",
    "- JAX 101: https://docs.jax.dev/en/latest/jax-101.html\n",
    "- Key concepts in JAX: https://docs.jax.dev/en/latest/key-concepts.html#jax-arrays-jax-array\n",
    "- Flax 101: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/index.html\n",
    "\n",
    "## Let's build a Transformer decoder layer!\n",
    "<small>_This is based upon the GPT decoder layer with causal masking, which prevents each position from attending to future positions._</small>\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Summary</b>\n",
    "    \n",
    "We build a basic Transformer layer using regular Flax modules. This will be our baseline for later comparisons with Transformer Engine.\n",
    "\n",
    "</div>\n",
    "\n",
    "Let's start with creating the transformer layer using plain [FLAX Linen](https://flax.readthedocs.io/en/stable/) . Figure 1 shows the overall structure.\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"transformer_layer.png\" width=\"20%\">\n",
    "<figcaption> Figure 1: Structure of a GPT decoder layer.</figcaption>\n",
    "</figure>\n",
    "\n",
    "We construct the components as follows:\n",
    "\n",
    "- `LayerNorm`: `nn.LayerNorm` (Flax)\n",
    "- `QKV Projection`: `nn.Dense` (conceptually there are three seperate `Dense` layers for Q, K, and V separately, but we fuse them together into a single `Dense` layer that is three times larger)\n",
    "- `DotProductAttention`: `nn.MuliheadDotProductAttention` (Flax)\n",
    "- `Projection`: `nn.Dense` (Flax)\n",
    "- `Dropout`: `nn.Dropout` (Flax)\n",
    "- `MLP`: `FlaxMLP` implemented using `nn.Dense` and `nn.gelu`\n",
    "\n",
    "Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_jax_utils.py](quickstart_jax_utils.py). Putting it all together:  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "d5284a38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from flax import linen as nn\n",
    "import quickstart_jax_utils as utils\n",
    "from typing import Optional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a4d1cfdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FlaxMLP(nn.Module):\n",
    "    \"\"\"Feed-forward network in Transformer layer\n",
    "    Built with plain Flax modules.\n",
    "    \"\"\"\n",
    "    hidden_size: int\n",
    "    ffn_hidden_size: int\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n",
    "        x = nn.Dense(features=self.ffn_hidden_size, use_bias=True)(x)\n",
    "        x = nn.gelu(x, approximate=True)  # equivalent to tanh approximation\n",
    "        x = nn.Dense(features=self.hidden_size, use_bias=True)(x)\n",
    "        return x\n",
    "\n",
    "class FlaxTransformerLayer(nn.Module):\n",
    "    \"\"\"Basic Transformer layer using plain Flax modules\"\"\"\n",
    "    hidden_size: int\n",
    "    ffn_hidden_size: int\n",
    "    num_attention_heads: int\n",
    "    layernorm_eps: float = 1e-5\n",
    "    attention_dropout: float = 0.1\n",
    "    \n",
    "    def setup(self):\n",
    "        self.kv_channels = self.hidden_size // self.num_attention_heads\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(\n",
    "        self, \n",
    "        x: jnp.ndarray, \n",
    "        attention_mask: Optional[jnp.ndarray] = None,\n",
    "        deterministic: bool = False\n",
    "    ) -> jnp.ndarray:\n",
    "        # Create causal mask if not provided\n",
    "        if attention_mask is None:\n",
    "            attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
    "        \n",
    "        res = x\n",
    "        x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
    "        \n",
    "        # Fused QKV projection\n",
    "        qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x)\n",
    "        qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
    "        q, k, v = jnp.split(qkv, 3, axis=3)\n",
    "        \n",
    "        # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n",
    "        # which is the correct format for dot_product_attention\n",
    "        \n",
    "        # Apply dot product attention\n",
    "        # Note: dot_product_attention expects mask to be broadcastable to \n",
    "        # [batch, num_heads, q_length, kv_length], but attention_mask from \n",
    "        # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n",
    "        \n",
    "        # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n",
    "        dropout_rng = None\n",
    "        if not deterministic and self.attention_dropout > 0:\n",
    "            dropout_rng = self.make_rng('dropout')\n",
    "        \n",
    "        x = nn.dot_product_attention(\n",
    "            query=q,\n",
    "            key=k,\n",
    "            value=v,\n",
    "            mask=attention_mask,\n",
    "            dropout_rng=dropout_rng,\n",
    "            dropout_rate=self.attention_dropout,\n",
    "            deterministic=deterministic,\n",
    "            broadcast_dropout=True,\n",
    "        )\n",
    "        \n",
    "        # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n",
    "        x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n",
    "        \n",
    "        x = res + x\n",
    "        \n",
    "        # Second residual connection\n",
    "        res = x\n",
    "        x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
    "        \n",
    "        # MLP\n",
    "        mlp = FlaxMLP(\n",
    "            hidden_size=self.hidden_size,\n",
    "            ffn_hidden_size=self.ffn_hidden_size,\n",
    "        )\n",
    "        x = mlp(x)\n",
    "        \n",
    "        return x + res\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbc3510b",
   "metadata": {},
   "source": [
    "## Testing Performance\n",
    "\n",
    "Now let's test the performance of our FlaxTransformerLayer:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "8b44649d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layer configuration\n",
    "hidden_size = 4096\n",
    "sequence_length = 2048\n",
    "batch_size = 4\n",
    "ffn_hidden_size = 16384\n",
    "num_attention_heads = 32\n",
    "dtype = jnp.bfloat16\n",
    "\n",
    "# Synthetic data\n",
    "key, dropout_key = jax.random.split(jax.random.PRNGKey(42))\n",
    "x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n",
    "dy = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "e44ed26d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pure Flax FlaxTransformerLayer initialized successfully!\n",
      "Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n"
     ]
    }
   ],
   "source": [
    "# Initialize the FlaxTransformerLayer\n",
    "flax_transformer = FlaxTransformerLayer(\n",
    "    hidden_size=hidden_size,\n",
    "    ffn_hidden_size=ffn_hidden_size,\n",
    "    num_attention_heads=num_attention_heads,\n",
    ")\n",
    "\n",
    "# Initialize parameters\n",
    "params = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
    "\n",
    "print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n",
    "print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "de91af7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input shape: (4, 2048, 4096)\n",
      "Output shape: (4, 2048, 4096)\n",
      "Output dtype: float32\n",
      "Forward pass completed successfully!\n"
     ]
    }
   ],
   "source": [
    "# Example usage of forward pass\n",
    "y = flax_transformer.apply(params, x, attention_mask=None, deterministic=True)\n",
    "print(f\"Input shape: {x.shape}\")\n",
    "print(f\"Output shape: {y.shape}\")\n",
    "print(f\"Output dtype: {y.dtype}\")\n",
    "print(\"Forward pass completed successfully!\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "037bc8d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean time: 18.546080589294434 ms\n"
     ]
    }
   ],
   "source": [
    "import importlib\n",
    "import quickstart_jax_utils\n",
    "importlib.reload(quickstart_jax_utils)\n",
    "\n",
    "utils.speedometer(\n",
    "    model_apply_fn=flax_transformer.apply,\n",
    "    variables=params,\n",
    "    input=x,\n",
    "    output_grad=dy,\n",
    "    dropout_key=dropout_key,\n",
    "    forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
    ")"
   ]
  },
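  {
   "cell_type": "markdown",
   "id": "9f3c2a1e",
   "metadata": {},
   "source": [
    "The `speedometer` helper is defined in [quickstart_jax_utils.py](quickstart_jax_utils.py). For readers who want to time a model outside this tutorial, the following is a minimal, hypothetical sketch of how a forward-plus-backward step can be timed in JAX; it is not the helper's actual implementation, and the function name `benchmark` and its parameters are illustrative. It JIT-compiles one step that runs the forward pass and then back-propagates the output gradient `dy` with `jax.vjp`, runs a few warm-up iterations (the first one also compiles), and calls `jax.block_until_ready` because JAX dispatches work asynchronously.\n",
    "\n",
    "```python\n",
    "import time\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from flax import linen as nn\n",
    "\n",
    "def benchmark(apply_fn, variables, inputs, dy, num_warmup=3, num_iters=10, **fwd_kwargs):\n",
    "    # One training-style step: forward pass, then backward pass seeded with dy\n",
    "    def fwd_bwd(variables, inputs, dy):\n",
    "        out, vjp_fn = jax.vjp(lambda v, inp: apply_fn(v, inp, **fwd_kwargs), variables, inputs)\n",
    "        return vjp_fn(dy.astype(out.dtype))  # gradients w.r.t. (variables, inputs)\n",
    "\n",
    "    step = jax.jit(fwd_bwd)\n",
    "    for _ in range(num_warmup):  # the first call also compiles the step\n",
    "        jax.block_until_ready(step(variables, inputs, dy))\n",
    "    start = time.perf_counter()\n",
    "    for _ in range(num_iters):\n",
    "        jax.block_until_ready(step(variables, inputs, dy))\n",
    "    return (time.perf_counter() - start) / num_iters * 1e3  # mean milliseconds per step\n",
    "\n",
    "# Usage on a tiny Flax model (illustrative sizes)\n",
    "toy_model = nn.Dense(features=16)\n",
    "toy_x = jnp.ones((4, 8))\n",
    "toy_variables = toy_model.init(jax.random.PRNGKey(0), toy_x)\n",
    "toy_dy = jnp.ones((4, 16))\n",
    "mean_ms = benchmark(toy_model.apply, toy_variables, toy_x, toy_dy)\n",
    "print(f'Mean time: {mean_ms:.3f} ms')\n",
    "```\n",
    "\n",
    "Under the same assumptions, the Flax baseline above could be timed with `benchmark(flax_transformer.apply, params, x, dy, rngs={'dropout': dropout_key}, deterministic=False)`."
   ]
  },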
  {
   "cell_type": "markdown",
   "id": "ccb16f31",
   "metadata": {},
   "source": [
    "## Meet Transformer Engine\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Summary</b>\n",
    "    \n",
    "Now that we have a basic Transformer layer in Flax, let's use Transformer Engine to speed up the training. The following examples show how to use TE modules.\n",
    "\n",
    "</div>\n",
    "\n",
    "As a reminder, the FlaxTransformerLayer above used:\n",
    "\n",
    "- `nn.LayerNorm`: Flax LayerNorm\n",
    "- `nn.Dense`: Flax Dense layer for QKV projection  \n",
    "- `nn.MultiheadDotProductAttention`: Flax MultiheadDotProductAttention\n",
    "- `nn.Dense`: Flax Dense layer for projection\n",
    "- `nn.Dropout`: Flax Dropout\n",
    "- `FlaxMLP`: Custom MLP implemented from `nn.Dense`\n",
    "\n",
    "Below we show how to use Transformer Engine Flax modules for better performance:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "bed20d6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformer_engine.jax as te\n",
    "import transformer_engine.jax.flax as te_flax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f28cb444",
   "metadata": {},
   "source": [
    "TE provides a set of Flax Linen modules that can be used to build Transformer layers. The simplest of the provided modules are the `DenseGeneral ` and `LayerNorm` layers, which we can use instead of `flax.linen.Dense` and ` flax.linen.LayerNorm`. Let's modify our `FlaxTransformerLayer`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "56105579",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention\n",
    "\n",
    "\n",
    "class TEUnfusedMLP(nn.Module):\n",
    "    hidden_size : int\n",
    "    ffn_hidden_size: int\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:\n",
    "        x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True) (x)\n",
    "        x = x.reshape(*x.shape[:-1], 1, x.shape[-1])\n",
    "        x = te.activation.activation(x, activation_type=('gelu',))\n",
    "        x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True) (x)\n",
    "        return x\n",
    "\n",
    "class TEUnfusedTransformerLayer(nn.Module):\n",
    "    hidden_size: int\n",
    "    ffn_hidden_size: int \n",
    "    num_attention_heads: int  \n",
    "    layernorm_eps: float = 1e-5\n",
    "    attention_dropout: float = 0.1 \n",
    "    use_te_attention: bool = True  # True for TE attention, False for Flax attention\n",
    "\n",
    "    def setup(self):\n",
    "        self.kv_channels = self.hidden_size // self.num_attention_heads\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(\n",
    "        self, \n",
    "        x: jnp.ndarray,\n",
    "        attention_mask: Optional[jnp.ndarray] = None,\n",
    "        deterministic: bool = False\n",
    "    ) -> jnp.ndarray:\n",
    "        # Create causal mask if not provided\n",
    "        if attention_mask is None:\n",
    "            attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
    "        \n",
    "        res = x\n",
    "        x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
    "\n",
    "        # Fused QKV projection\n",
    "        qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x)\n",
    "        qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
    "        q, k, v = jnp.split(qkv, 3, axis=3)\n",
    "\n",
    "        # Attention - either TE or Flax implementation\n",
    "        if self.use_te_attention:\n",
    "            # Use TE's DotProductAttention\n",
    "            attention = TEDotProductAttention(\n",
    "                head_dim=self.kv_channels,\n",
    "                num_attention_heads=self.num_attention_heads,\n",
    "                num_gqa_groups=self.num_attention_heads,  # No GQA\n",
    "                attention_dropout=self.attention_dropout,\n",
    "                attn_mask_type='causal',\n",
    "            )\n",
    "            x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n",
    "            # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
    "            x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n",
    "            x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n",
    "            x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n",
    "        else:\n",
    "            # Use Flax's MultiHeadDotProductAttention\n",
    "            q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n",
    "            k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n",
    "            v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n",
    "            \n",
    "            attention = nn.MultiHeadDotProductAttention(\n",
    "                num_heads=self.num_attention_heads,\n",
    "                qkv_features=self.kv_channels,\n",
    "                dropout_rate=self.attention_dropout,\n",
    "            )\n",
    "            x = attention(q_reshaped, k_reshaped, v_reshaped, mask=attention_mask, deterministic=deterministic)\n",
    "\n",
    "        x = res + x\n",
    "\n",
    "        # Second residual connection\n",
    "        res = x\n",
    "        x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
    "\n",
    "        # MLP\n",
    "        mlp = TEUnfusedMLP(\n",
    "            hidden_size=self.hidden_size,\n",
    "            ffn_hidden_size=self.ffn_hidden_size\n",
    "        )\n",
    "\n",
    "        x = mlp(x, deterministic=deterministic)\n",
    "\n",
    "        return x + res"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a76911ac",
   "metadata": {},
   "source": [
    "Testing performance of the model, using `DenseGeneral`, `LayerNorm` and activation from TE, while keeping Flax's `MultiHeadDotProductAttention` the same as the first simple Transformer in JAX implementation. To read more about this implementation from Flax, you can refer to this documentation:  https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "4b67511f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean time: 16.375374794006348 ms\n"
     ]
    }
   ],
   "source": [
    "te_unfused_transformer_with_flax_MHA = TEUnfusedTransformerLayer(\n",
    "    hidden_size, \n",
    "    ffn_hidden_size, \n",
    "    num_attention_heads,\n",
    "    use_te_attention=False\n",
    ")\n",
    "\n",
    "te_params = te_unfused_transformer_with_flax_MHA.init(key, x, attention_mask=None, deterministic=False)\n",
    "\n",
    "utils.speedometer(\n",
    "    model_apply_fn=te_unfused_transformer_with_flax_MHA.apply,\n",
    "    variables=te_params,  # Ensure the correct `params` is passed\n",
    "    input=x,\n",
    "    output_grad=dy,\n",
    "    dropout_key=dropout_key,\n",
    "    forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b230058",
   "metadata": {},
   "source": [
    "Now, we move on to also replace the attention sub-layer with TE's `DotProductAttention` implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "5146cd99",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:634: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
      "  warnings.warn(\n",
      "/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:742: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
      "Fall back to the unfused attention.\n",
      "Please try to update the cuDNN and TE to the latest version.\n",
      "self.dtype=<class 'jax.numpy.float32'>\n",
      "qkv_layout=<QKVLayout.BSHD_BSHD_BSHD: <NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: 9>>\n",
      "attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
      "attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
      "self.attention_dropout=0.1\n",
      "self.num_attention_heads=32\n",
      "self.num_gqa_groups=32\n",
      "seqlen_q=2048\n",
      "seqlen_kv=2048\n",
      "head_dim_qk=128\n",
      "head_dim_v=128\n",
      "\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean time: 12.403340339660645 ms\n"
     ]
    }
   ],
   "source": [
    "te_unfused_transformer = TEUnfusedTransformerLayer(\n",
    "    hidden_size, \n",
    "    ffn_hidden_size, \n",
    "    num_attention_heads,\n",
    ")\n",
    "\n",
    "te_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
    "\n",
    "utils.speedometer(\n",
    "    model_apply_fn=te_unfused_transformer.apply,\n",
    "    variables=te_params,  # Ensure the correct `params` is passed\n",
    "    input=x,\n",
    "    output_grad=dy,\n",
    "    dropout_key=dropout_key,\n",
    "    forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9a101d3",
   "metadata": {},
   "source": [
    "## Enabling Quantization (FP8 or FP4)\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Summary</b>\n",
    "    \n",
    "We configure a TE module to perform compute in FP8.\n",
    "\n",
    "</div>\n",
    "\n",
    "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/jax.rst#transformer_engine.jax.fp8_autocast) context manager. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options.\n",
    "\n",
    "<div class=\"alert alert-warning\">\n",
    "\n",
    "<b>Important: FP8 Metadata Initialization</b>\n",
    "\n",
    "When using FP8, the model **must be initialized within the `autocast` context**. This creates a special collection called `fp8_metas` that contains scaling factors and other metadata required for FP8 computation. If you initialize a model outside of `autocast` and then try to use it with FP8, you will get a `ScopeCollectionNotFound` error because the `fp8_metas` collection was never created.\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "c2eee376",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_engine.common.recipe import Format, DelayedScaling\n",
    "fp8_format = Format.HYBRID\n",
    "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "de96827c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean time: 9.396424293518066 ms\n"
     ]
    }
   ],
   "source": [
    "with te.autocast(enabled=True, recipe=fp8_recipe):\n",
    "    te_unfused_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
    "\n",
    "    # Example usage of forward \n",
    "    y = te_unfused_transformer.apply(te_unfused_params, x, attention_mask=None, deterministic=True)\n",
    "\n",
    "utils.speedometer(\n",
    "    model_apply_fn=te_unfused_transformer.apply,\n",
    "    variables=te_unfused_params,  # Ensure the correct `params` is passed\n",
    "    input=x,\n",
    "    output_grad=dy,\n",
    "    dropout_key=dropout_key,\n",
    "    forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
    "    autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3801b201",
   "metadata": {},
   "source": [
    "\n",
    "## Fused TE Modules\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Summary</b>\n",
    "    \n",
    "We optimize the example Transformer layer with TE modules for fused operations.\n",
    "\n",
    "</div>\n",
    "\n",
    "The `DenseGeneral` layer is enough to build any Transformer model and it enables usage of the Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations such as kernel fusions in mixed-precision recipes, increasing the achievable speedup.\n",
    "\n",
    "Transformer Engine therefore provides coarser modules that span multiple layers:\n",
    "\n",
    "* `LayerNormDenseGeneral`\n",
    "* `LayerNormMLP`\n",
    "* `TransformerLayer`\n",
    "\n",
    "To see a complete list of all the functions TE Flax support, you can view it here: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#modules\n",
    "\n",
    "Building a third iteration of our Transformer layer with `LayerNormDenseGeneral` and `LayerNormMLP`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "11203785",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TEFusedTransformerLayer(nn.Module):\n",
    "    hidden_size: int\n",
    "    ffn_hidden_size: int \n",
    "    num_attention_heads: int  \n",
    "    layernorm_eps: float = 1e-5\n",
    "    attention_dropout: float = 0.1\n",
    "\n",
    "    def setup(self):\n",
    "        self.kv_channels = self.hidden_size // self.num_attention_heads\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(\n",
    "        self, \n",
    "        x: jnp.ndarray,\n",
    "        attention_mask: Optional[jnp.ndarray] = None,\n",
    "        deterministic: bool = False\n",
    "    ) -> jnp.ndarray:\n",
    "        res = x\n",
    "\n",
    "         # Fused QKV projection\n",
    "        qkv,_ = te_flax.LayerNormDenseGeneral(features=3 * self.hidden_size, \n",
    "                                              epsilon=self.layernorm_eps, \n",
    "                                              use_bias=True, \n",
    "                                              return_layernorm_output=False)(x)\n",
    "        qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
    "        q, k, v = jnp.split(qkv, 3, axis=3)\n",
    "\n",
    "        # Attention using TE's DotProductAttention\n",
    "        attention = TEDotProductAttention(\n",
    "            head_dim=self.kv_channels,\n",
    "            num_attention_heads=self.num_attention_heads,\n",
    "            num_gqa_groups=self.num_attention_heads,  \n",
    "            attention_dropout=self.attention_dropout,\n",
    "            attn_mask_type='causal',\n",
    "        )\n",
    "        x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n",
    "        # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
    "        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n",
    "        x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n",
    "        x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n",
    "\n",
    "        x = res + x\n",
    "\n",
    "        # Second residual connection\n",
    "        res = x\n",
    "        x,_ = te_flax.LayerNormMLP(intermediate_dim=self.ffn_hidden_size, \n",
    "                                 epsilon=self.layernorm_eps,\n",
    "                                 use_bias=True,\n",
    "                                 activations=('gelu',),\n",
    "                                 intermediate_dropout_rate=0.0,\n",
    "                                 return_layernorm_output=False\n",
    "                                 )(x, deterministic=deterministic)\n",
    "\n",
    "        return x + res"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "334cff59",
   "metadata": {},
   "source": [
    "Similar to the unnfused model, we also compare the performance of fused model when using Flax's MultiheadDotProductAttention implementation and TE's."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "6b0c705e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean time: 9.145426750183105 ms\n"
     ]
    }
   ],
   "source": [
    "te_fused_transformer = TEFusedTransformerLayer(\n",
    "    hidden_size, \n",
    "    ffn_hidden_size, \n",
    "    num_attention_heads\n",
    ")\n",
    "\n",
    "with te.autocast(enabled=True, recipe=fp8_recipe):\n",
    "    te_fused_params = te_fused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
    "    # Example usage of forward \n",
    "    y = te_fused_transformer.apply(te_fused_params, x, attention_mask=None, deterministic=True)\n",
    "\n",
    "utils.speedometer(\n",
    "    model_apply_fn=te_fused_transformer.apply,\n",
    "    variables=te_fused_params,\n",
    "    input=x,\n",
    "    output_grad=dy,\n",
    "    dropout_key=dropout_key,\n",
    "    forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
    "    autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a45c12c8",
   "metadata": {},
   "source": [
    "Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "b2aaa8ef",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:742: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
      "Fall back to the unfused attention.\n",
      "Please try to update the cuDNN and TE to the latest version.\n",
      "self.dtype=<class 'jax.numpy.float32'>\n",
      "qkv_layout=<QKVLayout.BS3HD: <NVTE_QKV_Layout.NVTE_BS3HD: 5>>\n",
      "attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
      "attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
      "self.attention_dropout=0.1\n",
      "self.num_attention_heads=32\n",
      "self.num_gqa_groups=32\n",
      "seqlen_q=2048\n",
      "seqlen_kv=2048\n",
      "head_dim_qk=128\n",
      "head_dim_v=128\n",
      "\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "\n",
    "te_transformer = te_flax.TransformerLayer(\n",
    "    hidden_size=hidden_size,\n",
    "    mlp_hidden_size=ffn_hidden_size, \n",
    "    num_attention_heads=num_attention_heads,\n",
    "    mlp_activations=(\"gelu\",),\n",
    "    self_attn_mask_type='causal',\n",
    "    layernorm_epsilon=1e-5,\n",
    "    use_bias=True,\n",
    "    intermediate_dropout=0.0,\n",
    "    enable_relative_embedding=False,\n",
    "    self_attn_bias_type='no_bias',\n",
    "    hidden_dropout=0.0\n",
    ")\n",
    "\n",
    "with te.autocast(enabled=True, recipe=fp8_recipe):\n",
    "    te_transformer_params = te_transformer.init(key, x, deterministic=False)\n",
    "    y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "b9cdbf22",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean time: 9.020795822143555 ms\n"
     ]
    }
   ],
   "source": [
    "utils.speedometer(\n",
    "    model_apply_fn=te_transformer.apply,\n",
    "    model_init_fn=te_transformer.init,\n",
    "    variables=te_transformer_params,\n",
    "    input=x,\n",
    "    output_grad=dy,\n",
    "    dropout_key=dropout_key,\n",
    "    forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
    "    autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe }\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
